""" Skill Diffuser Implementation """

import time
from functools import partial
from typing import Any, Tuple, List, Dict, Union, Type, Optional, Callable

import gym
import numpy as np

import jax
import jax.numpy as jnp
import haiku as hk
import wandb

from sb3_jax.common.offline_algorithm import OfflineAlgorithm
from sb3_jax.common.buffers import BaseBuffer
from sb3_jax.common.type_aliases import GymEnv, MaybeCallback, Schedule
from sb3_jax.common.jax_utils import jax_print, jit_optimize, stop_grad

from diffgro.utils.utils import print_b
from diffgro.common.buffers import TrajectoryBuffer
from diffgro.common.models.utils import kl_div
from diffgro.experiments.skill_diffuser.policies import SkillDiffuserPolicy


class SkillDiffuser(OfflineAlgorithm):
    """Offline trainer for the Skill Diffuser policy.

    The policy exposes two sub-modules that are optimised separately in
    :meth:`train`:

    * ``plan`` -- trained on observation windows plus the task/language input.
      Its loss info carries ``total_loss``, ``vq_loss``, ``diff_loss`` and an
      updated ``state``.
    * ``inv`` -- trained on ``(obs, next_obs, act)`` triples (presumably an
      inverse-dynamics model; confirm against ``SkillDiffuserPolicy``).

    At inference time :meth:`predict` keeps a growing stack of observations
    for the current episode. Call :meth:`reset` between episodes.
    """

    def __init__(
        self,
        policy: Union[str, Type[SkillDiffuserPolicy]],
        env: Union[GymEnv, str],
        replay_buffer: Type[BaseBuffer] = TrajectoryBuffer,
        learning_rate: float = 3e-4,
        batch_size: int = 256,
        gamma: float = 0.99,
        gradient_steps: int = 1,
        tensorboard_log: Optional[str] = None,
        wandb_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = 1,
        _init_setup_model: bool = True,
    ):
        super(SkillDiffuser, self).__init__(
            policy,
            env,
            replay_buffer=replay_buffer,
            learning_rate=learning_rate,
            batch_size=batch_size,
            gamma=gamma,
            gradient_steps=gradient_steps,
            tensorboard_log=tensorboard_log,
            wandb_log=wandb_log,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            create_eval_env=False,
            seed=seed,
            _init_setup_model=False,
            # Must be a tuple of space types; ``(gym.spaces.Box)`` is just the class.
            supported_action_spaces=(gym.spaces.Box,),
            support_multi_env=False,
        )
        self.learning_rate = learning_rate
        # Per-episode observation history consumed by ``predict``. It is
        # initialised here so ``predict`` works even if ``reset`` was never called.
        self.obs_stack = None

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Seed RNGs, build the policy and expose its sub-modules as aliases."""
        self.set_random_seed(self.seed)
        self.policy = self.policy_class(
            self.observation_space,
            self.action_space,
            seed=self.seed,
            **self.policy_kwargs,
        )
        self._create_aliases()

    def _create_aliases(self) -> None:
        """Expose ``policy.plan`` and ``policy.inv`` as ``self.plan`` / ``self.inv``."""
        self.plan = self.policy.plan
        self.inv = self.policy.inv

    def train(self, gradient_steps: int, batch_size: int = 256) -> None:
        """Run ``gradient_steps`` optimisation rounds on ``inv`` and ``plan``.

        Each round does the following:

        1. Sample one batch.
        2. Take one optimiser step on ``inv``.
        3. Slide over the sequence axis of ``obs`` with stride
           ``policy.horizon``, taking one optimiser step on ``plan`` per
           window.

        Losses are averaged over every step of this call and logged once at
        the end.

        :param gradient_steps: number of batches to sample and optimise on.
        :param batch_size: number of trajectory segments per batch.
        """
        if gradient_steps <= 0:
            return

        horizon = self.policy.horizon
        plan_horizon = self.plan.plan_horizon
        losses = {
            "total_loss": [],
            "inv_loss": [],
            "vq_loss": [],
            "diff_loss": [],
        }

        for _ in range(gradient_steps):
            self._n_updates += 1
            obs, lang, act, next_obs = self._sample_replay_data(batch_size)

            self.inv.optim_state, self.inv.params, inv_loss, _ = jit_optimize(
                self.inv._compute_loss,
                self.inv.optim,
                self.inv.optim_state,
                self.inv.params,
                max_grad_norm=None,
                obs=obs,
                next_obs=next_obs,
                act=act,
                rng=next(self.inv.rng),
            )
            losses["inv_loss"].append(inv_loss)

            for ts in range(1, obs.shape[1], horizon):
                # Observations before ``ts`` (passed as ``obs_skill_prd``).
                obs_skill_prd = obs[:, :ts]
                # Up to ``plan_horizon`` observations starting at ``ts``. Near the
                # end of the sequence the slice is shorter, so it is right-padded
                # by repeating its last observation to a fixed length.
                obs_cond_diff = obs[:, ts : ts + plan_horizon]
                obs_pad = jnp.tile(
                    obs_cond_diff[:, -1][:, None],
                    (1, plan_horizon - obs_cond_diff.shape[1], 1),
                )
                obs_cond_diff = jnp.concatenate([obs_cond_diff, obs_pad], axis=1)

                self.plan.optim_state, self.plan.params, _, plan_info = jit_optimize(
                    self.plan._compute_loss,
                    self.plan.optim,
                    self.plan.optim_state,
                    self.plan.params,
                    max_grad_norm=None,
                    obs_skill_prd=obs_skill_prd,
                    obs_cond_diff=obs_cond_diff,
                    lang=lang,
                    state=self.plan.state,
                    rng=next(self.plan.rng),
                )
                self.plan.state = plan_info["state"]
                # Debug output only: formatting a device array forces a host sync.
                if self.verbose >= 2:
                    print_b(f"total loss: {plan_info['total_loss']}")

                for key in ("total_loss", "vq_loss", "diff_loss"):
                    losses[key].append(plan_info[key])

        def _mean(values: List[Any]) -> float:
            # ``np.mean([])`` warns; an empty list (no planner windows) logs nan.
            return float(np.mean(values)) if values else float("nan")

        total_loss = _mean(losses["total_loss"])
        vq_loss = _mean(losses["vq_loss"])
        diff_loss = _mean(losses["diff_loss"])
        inv_loss = _mean(losses["inv_loss"])

        self.logger.record("train/total/loss", total_loss)
        self.logger.record("train/vq/loss", vq_loss)
        self.logger.record("train/diff/loss", diff_loss)
        self.logger.record("train/inv/loss", inv_loss)

        # ``wandb.log`` requires an active run; skip silently when there is none.
        if wandb.run is not None:
            wandb_metrics = {
                "time/total_timesteps": self.num_timesteps,
                "train/loss/total": total_loss,
                "train/loss/vq": vq_loss,
                "train/loss/diff": diff_loss,
                "train/loss/inv": inv_loss,
            }
            wandb.log(wandb_metrics)

    def _sample_replay_data(
        self, batch_size: int
    ) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        """Sample ``batch_size`` trajectory segments and split them for training.

        Segments are requested with ``max_length = plan.plan_horizon + 1``, so that
        shifting by one step yields aligned ``(obs, act, next_obs)`` sequences.

        :param batch_size: number of segments to sample.
        :return: ``(obs, lang, act, next_obs)``.

            - ``obs`` and ``act`` drop the last step along axis 1.
            - ``next_obs`` drops the first step.
            - ``lang`` is the sampled ``tasks`` array with a new axis inserted
              at position 1.
        """
        batch_keys = ["tasks", "observations", "actions"]
        max_length = self.plan.plan_horizon + 1
        replay_data = self.replay_buffer.sample(
            batch_keys, batch_size, max_length=max_length
        )

        obs = replay_data.observations[:, :-1]
        lang = replay_data.tasks[:, jnp.newaxis]
        act = replay_data.actions[:, :-1]
        next_obs = replay_data.observations[:, 1:]

        return obs, lang, act, next_obs

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
        tb_log_name: str = "SkillDiffuser",
    ) -> "SkillDiffuser":
        """Train for ``total_timesteps`` iterations, calling :meth:`train` once each.

        Every ``log_interval`` iterations, fps, elapsed time and the iteration
        count are recorded and the logger is dumped. Training stops early (but
        still runs ``callback.on_training_end``) when ``callback.on_step()``
        returns ``False``.

        :param total_timesteps: number of training iterations to run.
        :param log_interval: dump the logger every this many iterations;
            ``None`` disables it.
        :return: ``self``, so calls can be chained.
        """
        self.log_interval = log_interval
        total_timesteps, callback = self._setup_learn(
            total_timesteps,
            eval_env,
            callback,
            eval_freq,
            n_eval_episodes,
            eval_log_path,
            reset_num_timesteps,
            tb_log_name,
        )
        callback.on_training_start(locals(), globals())

        # 2. learn policy module
        start_time = time.time()
        num_timesteps = 0
        while num_timesteps < total_timesteps:
            self.train(gradient_steps=self.gradient_steps, batch_size=self.batch_size)

            self.num_timesteps += 1
            num_timesteps += 1
            if log_interval is not None and num_timesteps % log_interval == 0:
                # Clamp to avoid dividing by ~0 on coarse clocks.
                elapsed = max(time.time() - start_time, 1e-8)
                fps = int(num_timesteps / elapsed)
                self.logger.record("time/fps", fps)
                self.logger.record(
                    "time/time_elapsed",
                    int(elapsed),
                    exclude="tensorboard",
                )
                self.logger.record(
                    "time/total_timesteps", num_timesteps, exclude="tensorboard"
                )
                self.logger.dump(step=num_timesteps)

            callback.update_locals(locals())
            if callback.on_step() is False:
                # Stop early but still finish the callback lifecycle and
                # return ``self`` as the signature promises.
                break

        callback.on_training_end()
        return self

    def predict(
        self,
        obs: jnp.ndarray,
        lang: jnp.ndarray,
    ) -> jnp.ndarray:
        """Append ``obs`` to the episode's observation stack and return an action.

        The first call of an episode only starts the stack and returns a zero
        action shaped like ``policy.action_space.sample()``. Later calls pass
        the whole stack (shape ``(1, t, -1)`` after flattening each
        observation) and ``lang`` to ``policy._predict``.
        """
        obs_step = obs.reshape(1, 1, -1)

        if self.obs_stack is None:
            self.obs_stack = obs_step
            noop_act = jnp.zeros_like(self.policy.action_space.sample())
            return noop_act

        self.obs_stack = jnp.concatenate([self.obs_stack, obs_step], axis=1)

        return self.policy._predict(self.obs_stack, lang)

    def reset(self):
        """Clear the observation stack and reset ``policy.offset`` for a new episode."""
        self.obs_stack = None
        self.policy.offset = 0

    def load_params(self, path: str) -> None:
        """Load ``plan``/``inv`` parameters and the normalization layer from ``path``."""
        # Imported locally: not among this module's top-level imports.
        from sb3_jax.common.save_util import load_from_zip_file

        print_b(f"[skill diffuser] : loading params")
        _, params = load_from_zip_file(path, verbose=1)
        self._load_jax_params(params)
        self._load_norm_layer(path)

    def _save_jax_params(self) -> Dict[str, hk.Params]:
        """Return the parameters to save, keyed ``plan_params`` / ``inv_params``."""
        params_dict = {}
        params_dict["plan_params"] = self.plan.params
        params_dict["inv_params"] = self.inv.params
        return params_dict

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """Hand the full saved-params dict to both ``plan`` and ``inv`` loaders."""
        self.plan._load_jax_params(params)
        self.inv._load_jax_params(params)

    def _save_norm_layer(self, path: str) -> None:
        """Save the policy's normalization layer, if the policy uses one."""
        if self.policy.normalization_class is not None:
            self.policy.normalization_layer.save(path)

    def _load_norm_layer(self, path: str) -> None:
        """Load the policy's normalization layer from ``path``, if the policy uses one."""
        if self.policy.normalization_class is not None:
            self.policy.normalization_layer = self.policy.normalization_layer.load(path)
